Skip to content

[Pytorch][Common] Hybrid quantization#2817

Open
negvet wants to merge 25 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization
Open

[Pytorch][Common] Hybrid quantization#2817
negvet wants to merge 25 commits into
NVIDIA:mainfrom
negvet:hybrid_quantization

Conversation

@negvet

@negvet negvet commented Mar 31, 2026

Copy link
Copy Markdown
Collaborator

Description

Hybrid (per-direction) quantization. Functional.
C++ optimizations (fusions, etc.) will come in the next PRs.

TODO:

  • Double quantization
  • Non-hybrid convergence of base recipes (validation)
  • DCP/torch_dist doesn't preserve hybrid current-scaling primary-weight _scale_inv; fix via __tensor_flatten__/__tensor_unflatten__ across the tensor stack — torch.save/load and FP32-master paths unaffected; covered by the test_hybrid_dcp_output_parity xfail.

Integration

Ecosystem integration (all functional, unit-tested):

  • [Done] quantized_model_init
  • [Done] FSDP2 (TODO: optimize communication buffers)
  • [Done] CPU offloading
  • [Done] Activation recomputation
  • [Done] TP/SP (TODO: enable quantized AG)

Megatron-LM integration status:

  • [Done] 1 GPU baseline
  • [Done] DP + distributed optimizer
  • [TODO] quantized_model_init + --fp{4,8}-param-gather + dist opt (persistent low-precision params via quantized_model_init + sharded-master FP32 → quantized cast via quantize_master_weights.)
    - [Done] Per-tensor Float8 hybrid (delayed and/or current, any per-direction combination
    including same-format, cross-format Float8, single-direction)
    - [TODO] Per-block hybrid sub-quantizers (MXFP8, NVFP4, Float8Blockwise) — each rejected per-direction by quantize_master_weights; unblocker is TE-side cast-helper / kernel.
  • [TODO] Megatron-FSDP + --fp{4,8}-param-gather (fix private attribute access)
  • [TODO] Torch FSDP2 + --fp{4,8}-param-gather
    - [Done] TE-side hybrid FSDP2 path works end-to-end for Float8 / MXFP8 / Float8Blockwise sub-storages (TODO: need some minor MLM update)
    - [TODO] NVFP4 sub-storage FSDP2 hooks
  • [Done] Activation recompute
  • [Done] CPU offload
  • [Done] TP/SP/PP
  • [Done] MoE + EP + grouped GEMM (qwen3 MoE; _hybrid_split_quantize under Megatron MoE)

Review

Total diff +9000
New hybrid source (hybrid_tensor.py, hybrid_tensor_storage.py) ~1000
Adjacent modifications ~1000
Tests are the rest

Surface to review is ~2000 lines

Suggested reading order

  1. Foundation — 7553e6a: Python containers + quantize/gemm dispatch/unwrap
  • tensor/hybrid_tensor.py — HybridQuantizer + HybridQuantizedTensor
  • tensor/storage/hybrid_tensor_storage.py
  • cpp_extensions/gemm.py — _unwrap_hybrid_A/B
  • common/transpose/quantize_transpose_square_blockwise.cu - Block FP8 columnwise-only null-checks
  • Module hooks in module/{base,grouped_linear,layernorm_linear,layernorm_mlp}.py
  • Tests: TestHybridQuantizer*, TestHybridGemmBitwiseIdentical* (proves zero-overhead vs vanilla recipes when both formats match), TestHybridDirectionUnwrap*, TestHybridGroupedLinear*
  1. quantized_model_init + FusedAdam — f80f5d0
  • hybrid_tensor.py::HybridQuantizer.update_quantized — delegates to each sub-quantizer; unblocks workspace-cache quantize_() and FusedAdam writeback
  • module/base.py workspace-cache invalidation
  • Tests: TestHybridQuantizedModelInit, TestHybridFusedAdam, TestHybridQuantizedParamsEndToEnd, TestHybridCheckpoint, TestQuantizedParamsEquivalence*
  1. FSDP2 support — 2185b30
  • New base FSDP2 buffer protocol on QuantizedTensorStorage: fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered. Generic, reusable beyond hybrid.
  • Per-format overrides on Float8TensorStorage (direction-aware) and MXFP8TensorStorage (trips/re-applies scale alignment padding around the gather)
  • hybrid_tensor.py::fsdp_pre/post_all_gather + torch_dispatch for the FSDP2 op set (view, split, as_strided, slice, copy_, new_zeros, clone, detach)
  • Non-safety in float8_tensor.py and mxfp8_tensor.py for single-direction sub-storages (columnwise-only on Hopper/L40)
  • Tests: TestHybridTorchDispatchFSDP2Ops, TestHybridFsdpPreAllGatherProtocol, TestHybridFsdpRoundtrip (bitwise-exact against manual all_gather(dequantize(shard))), plus tests/pytorch/distributed/fsdp2_tests/
  1. CPU offloading — 103fffe
  • hybrid_tensor_storage.py::clear() (v1 path) + prepare_for_saving / restore_from_saved chain (v2 path)
  • hybrid_tensor.py::detach() re-wraps each sub-storage via make_like (required by cpu_offload_v2's detach → prepare_for_saving pattern; sharing sub-storage objects would null-out fields on the original)
  • TestHybridCpuOffloadPushPop, plus updates to test_cpu_offloading*.py
  1. Activation recomputation — 16fb371
  • Uses existing QuantizedTensorStorage::prepare_for_saving / restore_from_saved protocol, preserving ordering across both sub-storages
  • Tests: 20 bitwise tests in TestHybridActivationRecompute
  1. TP/SP — a50fd63
  • hybrid_tensor.py::HybridQuantizer.supports_only_rowwise_all_gather — overrides to handle the NVFP4 columnwise-dequantize gap in the BF16 fallback path
  • distributed.py::gather_along_first_dim — hybrid branch re-quantizes with both directions after AG (since hybrid has no _create_transpose synthesis path)
  • Tests: 9 distributed tests in run_hybrid_tp_sp.py / test_hybrid_tp_sp.py
  1. Megatron-LM integration — a164cd3
  • tensor/utils.py::_route_hybrid_to_buckets — per-direction dispatch for quantize_master_weights: iterates both sub-storages, routes each independently into the per-format bucket matching its own sub-quantizer type
  • Hybrid branches in replace_raw_data and post_all_gather_processing
  • Today: per-tensor Float8 sub-quantizers (delayed + current) work in any per-direction combination. Per-block sub-quantizers raise per-direction with in-code TODOs naming the unblocker.
  • Tests: TestHybridQuantizeMasterWeights, TestHybridPostAllGatherProcessing

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps

greptile-apps Bot commented Mar 31, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces hybrid (per-direction) quantization for PyTorch, allowing rowwise and columnwise GEMM operands to use different quantization formats (e.g., MXFP8 rowwise + NVFP4 columnwise). The change adds HybridQuantizer, HybridQuantizedTensor, HybridQuantizedTensorStorage, and IdentityQuantizer/IdentityTensor as new first-class types, wired into the module, FSDP2, TP/SP, CPU offloading, activation recompute, distributed optimizer, and grouped GEMM paths.

  • Core containers (hybrid_tensor.py, hybrid_tensor_storage.py): two-pass quantization dispatch, full __torch_dispatch__ coverage for FSDP2 ops, direction-aware fsdp_pre/post_all_gather, detach/clone/pickling. Previously flagged issues (try/finally in make_empty, repr NoneType, mixed-quantizer crash, None-entry list, fsdp_post_all_gather update_usage gap, direction-aware fsdp_buffer_fields) are all addressed.
  • GEMM dispatch (cpp_extensions/gemm.py): _unwrap_hybrid_A/B route each operand to the correct sub-storage based on layout transpose flags; _materialize_high_precision handles IdentityTensorStorage passthrough.
  • Distributed optimizer (tensor/utils.py): _route_hybrid_to_buckets decomposes a hybrid weight into per-direction entries routed into the existing delayed/current-scaling/identity cast buckets; per-block sub-quantizers (MXFP8, NVFP4, Float8Block) raise with clearly documented TODOs and unblocker shapes.

Confidence Score: 5/5

Safe to merge; all previously flagged blocking issues are addressed and the new code is well-structured with extensive tests and inline documentation of known gaps.

The core hybrid containers, FSDP2 protocol, GEMM dispatch, distributed optimizer routing, and CPU offload paths are cleanly implemented. Previously flagged issues (exception safety in make_empty, repr NoneType, mixed-quantizer crash, None-entry list handling, fsdp_post_all_gather update_usage gap, direction-aware fsdp_buffer_fields) are all addressed in this revision. The remaining observations are non-critical: _hybrid_split_quantize quantizes both directions regardless of inference-mode usage flags (memory waste, no wrong output), and the MXFP8 split degenerate-None path is equivalent to the original code's failure mode. Known gaps (DCP, NVFP4 FSDP2, per-block distopt) are clearly documented as xfail tests and inline TODOs.

transformer_engine/pytorch/module/grouped_linear.py (_hybrid_split_quantize usage-flag handling) and transformer_engine/pytorch/tensor/mxfp8_tensor.py (split degenerate-None guard).

Important Files Changed

Filename Overview
transformer_engine/pytorch/tensor/hybrid_tensor.py New HybridQuantizer + HybridQuantizedTensor implementation; includes FSDP2 pre/post all-gather, aten dispatch for view/split/copy_/clone/new_zeros/as_strided, detach, pickling, and TP/SP supports_only_rowwise_all_gather. Well-documented and carefully handles direction-specific edge cases.
transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py New HybridQuantizedTensorStorage; correctly implements prepare_for_saving/restore_from_saved, clear, view, dequantize, size, and get_metadata. update_usage is intentionally one-way (drop-only), consistent with how callers use it.
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py Adds direction-aware fsdp_buffer_fields and fsdp_assign_gathered overrides; correctly handles columnwise-only sub-storages (_transpose field) and clears _transpose_invalid after gathered write-back.
transformer_engine/pytorch/tensor/float8_tensor.py Adds None-safety for _data in clone() and aten.split dispatch; correctly inverts transpose shape (*shape[1:], shape[0]) when _data is None for columnwise-only sub-storages.
transformer_engine/pytorch/module/grouped_linear.py Adds Identity and hybrid quantizer dispatch paths via _hybrid_split_quantize/_split_quantize_with_identity_fallback; _hybrid_split_quantize does not respect parent HybridQuantizer usage flags, causing redundant columnwise quantization during inference.
transformer_engine/pytorch/tensor/utils.py Adds _route_hybrid_to_buckets, _cast_master_weights_to_identity, and hybrid branches in post_all_gather_processing; _cast_master_weights_to_identity validates start_offset but silently ignores it in the FSDP shard path.
transformer_engine/pytorch/cpp_extensions/gemm.py Adds _unwrap_hybrid_A/B and _materialize_high_precision helpers; correctly routes hybrid GEMM operands based on layout transpose flags for all four GEMM layouts (TN/NN/NT/TT).
transformer_engine/pytorch/tensor/identity_tensor.py New IdentityQuantizer + IdentityTensor for high-precision passthrough; correctly implements the full QuantizedTensorStorage protocol including FSDP2 buffer fields, prepare_for_saving, and dequantize.
transformer_engine/pytorch/tensor/mxfp8_tensor.py Adds None-safety for _rowwise_data in clone() and refactors split dispatch; new ref_splits=None crash path when both data fields are None is a degenerate-case regression from the original's zip(*out_data) TypeError.
transformer_engine/pytorch/distributed.py Adds HybridQuantizer branch in gather_along_first_dim that restores usage flags via try/finally; correctly re-quantizes both directions post-AG until native hybrid AG dispatch is implemented.
transformer_engine/pytorch/module/base.py Adds HybridQuantizer to bgrad_quantize bypass, workspace validity check for HybridQuantizedTensorStorage, and amax_reduction_group propagation; changes are clean and well-commented.

Class Diagram

%%{init: {'theme': 'neutral'}}%%
classDiagram
    class Quantizer {
        +rowwise_usage: bool
        +columnwise_usage: bool
        +quantize(tensor)
        +make_empty(shape)
        +update_quantized(src, dst)
    }

    class HybridQuantizer {
        +rowwise_quantizer: Quantizer
        +columnwise_quantizer: Quantizer
        +supports_only_rowwise_all_gather()
        +quantize_impl(tensor)
        +update_quantized(src, dst)
    }

    class IdentityQuantizer {
        +dtype: Optional[torch.dtype]
        +quantize_impl(tensor)
    }

    class QuantizedTensorStorage {
        +fsdp_buffer_fields()
        +fsdp_extract_buffers()
        +fsdp_assign_gathered(gathered, meta)
        +prepare_for_saving()
        +restore_from_saved(tensors)
    }

    class HybridQuantizedTensorStorage {
        +_rowwise_storage: Optional[QTStorage]
        +_columnwise_storage: Optional[QTStorage]
        +_rowwise_quantizer
        +_columnwise_quantizer
        +update_usage(row, col)
        +clear()
        +dequantize()
        +view()
    }

    class HybridQuantizedTensor {
        +fsdp_pre_all_gather()
        +fsdp_post_all_gather()
        +detach()
        +__torch_dispatch__()
        +__reduce_ex__()
    }

    class IdentityTensorStorage {
        +_hp_data: torch.Tensor
        +fsdp_buffer_fields()
        +dequantize()
        +update_usage()
    }

    class Float8TensorStorage {
        +fsdp_buffer_fields()
        +fsdp_assign_gathered()
    }

    class MXFP8TensorStorage {
        +fsdp_buffer_fields()
        +fsdp_extract_buffers()
        +fsdp_assign_gathered()
    }

    Quantizer <|-- HybridQuantizer
    Quantizer <|-- IdentityQuantizer
    QuantizedTensorStorage <|-- HybridQuantizedTensorStorage
    QuantizedTensorStorage <|-- IdentityTensorStorage
    QuantizedTensorStorage <|-- Float8TensorStorage
    QuantizedTensorStorage <|-- MXFP8TensorStorage
    HybridQuantizedTensorStorage <|-- HybridQuantizedTensor
    HybridQuantizer --> HybridQuantizedTensorStorage : produces
    HybridQuantizedTensorStorage --> QuantizedTensorStorage : _rowwise_storage
    HybridQuantizedTensorStorage --> QuantizedTensorStorage : _columnwise_storage
    IdentityQuantizer --> IdentityTensorStorage : produces
Loading

Reviews (10): Last reviewed commit: "Bug fixing" | Re-trigger Greptile

Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/storage/hybrid_tensor_storage.py
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall I think this moves us in a good direction. I see some minor bugs, as well as bugs reported by @greptile-apps.

Comment on lines +52 to +53
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we handle the case where not all usages are needed? I'd expect something like:

Suggested change
rowwise_result = self.rowwise_quantizer.quantize(tensor)
columnwise_result = self.columnwise_quantizer.quantize(tensor)
rowwise_result = self.rowwise_quantizer.quantize(tensor) if self.rowwise_usage else None
columnwise_result = self.columnwise_quantizer.quantize(tensor) if self.columnwise_usage else None

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requires_grad: bool = False,
pin_memory: bool = False,
) -> HybridQuantizedTensor:
self.rowwise_quantizer.internal = True

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we just set internal=True in the constructor? I don't think we ever need PyTorch tensor functionality in the per-usage data.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not work under FSDP2.

Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py Outdated
Comment on lines +1339 to +1355
def factory(role):
if role == "linear_weight":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_mxfp8_quantizer(),
)
if role == "linear_input":
return HybridQuantizer(
rowwise_quantizer=_make_fp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if role in ("linear_grad_output", "linear_grad_input"):
return HybridQuantizer(
rowwise_quantizer=_make_mxfp8_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
return None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is horrifying. Good test.

negvet and others added 10 commits April 6, 2026 10:26
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/tensor/hybrid_tensor.py
negvet and others added 2 commits April 29, 2026 16:02
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Comment thread transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
negvet added 3 commits May 13, 2026 12:34
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
@negvet negvet requested a review from ksivaman as a code owner May 21, 2026 13:53
Comment on lines 665 to 677
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
shape=(
split_tensor.shape
if split_tensor is not None
else split_transpose_tensor.shape
),
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 When _data is None (columnwise-only sub-storage of a HybridQuantizedTensor on non-Hopper), the split falls back to split_transpose_tensor.shape, which is the transposed layout's shape [K, M/n]. The correct nominal shape for the shard is [M/n, K]. This wrong nominal shape propagates into the HybridQuantizedTensor through fsdp_post_all_gather (which calls _infer_shape on the gathered _transpose buffer to build col_sub), so after the first FSDP2 iteration the assembled full-parameter hybrid's _columnwise_storage reports [K, M] instead of [M, K]. Any Python-side code that calls .size() on that sub-storage (e.g., HybridQuantizedTensorStorage.size() when rowwise is also None, workspace-validity checks, debugging assertions) will see the wrong dimensions.

Suggested change
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=split_tensor.shape,
shape=(
split_tensor.shape
if split_tensor is not None
else split_transpose_tensor.shape
),
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]
outs = [
Float8Tensor.make_like(
tensor,
data=split_tensor,
data_transpose=split_transpose_tensor,
shape=(
split_tensor.shape
if split_tensor is not None
# _transpose has shape [K, M/n] but the shard's nominal shape
# is [M/n, K]. Recover the correct shard shape by reversing
# the last two dims of the transposed piece.
else (*split_transpose_tensor.shape[1:], split_transpose_tensor.shape[0])
),
)
for split_tensor, split_transpose_tensor in zip(func_out, t_func_out)
]

Comment on lines +27 to +30
# DCP serializes ``CustomRecipe`` via ``pickle``; closure-based qfactories
# (lambdas, inner functions referencing captured state) are not picklable,
# so the qfactory must live at module scope. See
# ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is potentially useful, but I don't think it is in the right place - shouldn't it be closer to the actual implementation?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +1177 to +1184
for param in model.parameters():
state = optimizer.state[param]
assert state["exp_avg"].dtype == torch.float32
assert state["exp_avg_sq"].dtype == torch.float32
if "master_param" in state:
assert state["master_param"].dtype == torch.float32

assert losses[-1] < losses[0], f"Loss did not decrease: {losses[0]:.4f} -> {losses[-1]:.4f}"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not a very strict test, is there a way for us to do some numerical correctness comparisons?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enabled check for the monotonic loss decrease (still mostly sanity), and also enabled hybrid vs vanilla bitwise recipe comparizon, see e.g. test_fused_adam_hybrid_vs_base_recipe_parity.

# Quantized training may diverge from bf16, but should not be wildly different.
for step, (h_loss, b_loss) in enumerate(zip(hybrid_losses, bf16_losses)):
ratio = h_loss / max(b_loss, 1e-10)
assert 0.1 < ratio < 10.0, (

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm... What are the actual values there? Could we maybe set a seed or something and compare with some more reasonable tolerance?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a few % from bf16, tightened the tolerance, please take a look

"HybridMXFP8": dict(rtol=0.0, atol=0.0),
"HybridMixed_MXFP8_FP8": dict(rtol=0.0, atol=0.0),
}
tolerance = _TIGHT_TOLERANCE.get(hybrid_recipe_name, dict(rtol=1e-6, atol=1e-6))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, again, what are the actual values here? I understand the general idea here that when the format is stateful then you will have some difference, but I don't think that 1e-6 would be the right tolerance if that difference actually happened. So if we do not actually test the case that would exhibit this issue then maybe it would be better to just set the tolerances to 0 in all cases to simplify the test code? This would then also be a clear point of failure if somebody added here a recipe that would not exhibit this same property.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I agree, just setting to 0.

Comment on lines +1403 to +1404
assert per_rank_out % 32 == 0, "MXFP8 data alignment precondition"
assert per_rank_out % 128 != 0, "Test precondition: shard must need scale padding"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those error messages, since they are purely meant for the person changing the test itself, could be more descriptive in the error message.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +1462 to +1482
if hybrid_recipe_name == "HybridFloat8BlockScaling":
pytest.xfail(
"HybridFloat8BlockScaling: Float8BlockwiseQTensor sub-storage loses "
"quantized type through FSDP2 view(-1)."
)

if hybrid_recipe_name == "HybridFP8CurrentScaling":
pytest.xfail(
"HybridFP8CurrentScaling: per-tensor _scale_inv is not preserved "
"through DCP's tensor-storage-byte serialization path. "
"HybridQuantizedTensor.__reduce_ex__ correctly round-trips through "
"pickle (verified by torch.save/torch.load), but DCP bypasses "
"pickle and serializes the tensor's storage bytes — the scalar "
"_scale_inv is not enumerated as a separate tensor leaf and gets "
"lost. Vanilla Float8CurrentScaling avoids this because per-tensor "
"scale lives in module.fp8_meta (saved as extra_state), not on "
"the tensor; hybrid uses per-sub-storage scales without that "
"mirror. Fix path: implement __tensor_flatten__/__tensor_unflatten__ "
"across the quantized tensor stack so DCP can serialize the "
"per-leaf tensor fields directly. Loaded model output diverges by "
"~5e-2."

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, do we intend to do something about that?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, a proper fix would touch all tensors, so let's do that in a separate PR, added a TODO at the description above.


def _build_hybrid_model(num_layers, hybrid_recipe, use_meta_device=True):
"""Build a model with quantized_model_init using a hybrid CustomRecipe."""
ctx = te.quantized_model_init(enabled=True, recipe=hybrid_recipe)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This quite strange choice (it was also in the other test files) to separate the ctx definition from the usage. It is not a big deal (basically a nit), but it looks strange - it is better to have things close to the actual usage site if they are not too big.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Comment on lines +490 to +494
if hybrid_recipe_name == "HybridFloat8BlockScaling":
pytest.xfail(
"HybridFloat8BlockScaling: Float8BlockwiseQTensor sub-storage loses "
"quantized type through FSDP2 view(-1)."
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any plans to address it? Is it a limitation of the underlying recipe? @vthumbe1503 any thoughts here?

@vthumbe1503 vthumbe1503 Jun 3, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the QuantizedTensors have view(-1) implemented to return the dequantized output, so there should be something more going on in Float8BlockScaling that might be causing the test to fail here.

Also the view(-1) in FSDP2 is done to store a sharded tensor just for checkpointing logic and isnt used anywhere.

@negvet negvet Jun 3, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the insight @vthumbe1503

This is a hybrid-related bug.
After fixing the intermediate view-related bug, it turns out that Float8BlockwiseQTensorStorage does not implement the FSDP2 sub-storage protocol (fsdp_buffer_fields / fsdp_extract_buffers / fsdp_assign_gathered) required by hybrid all-gather.

Fixing it, WIP.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hybrid_avg = sum(hybrid_increments) / len(hybrid_increments)

excess_per_layer = hybrid_avg - bf16_avg
tolerance_per_layer = 50 * 1024 # 50 KiB

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the base of that tolerance? What are the actual values?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment that replies:

# Basis: forward growth is constant per layer (no accumulation) for both bf16 and
# hybrid; the excess is just hybrid's extra per-layer quantized buffers. Measured
# excess: ~3 KiB (FP8 current) / ~7 KiB (mixed MXFP8+FP8) / ~12 KiB (MXFP8). A
# leaked layer's quantized weights would be hundreds of KiB, so 50 KiB sits above
# the real per-layer overhead and well below a leak.

)

excess = hybrid_bwd_delta - bf16_bwd_delta
tolerance = 256 * 1024 # 256 KiB

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this one larger than the previous one?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

50 is a per-layer forward, whereas 256 is a whole-model backward + optimizer step

loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
dist_print(f"Hybrid iteration {iteration} completed with loss {loss.item()}")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this test actually checking? There are no assertions here - if we only check if it does not crash then what is the value of this test vs the other ones?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, this was just a smoke test. Updated it to assert loss finiteness + strict monotonic decrease, hybrid-type preservation across the optimizer step, and FSDP2 all-gather correctness vs a manual fp32 dequant-then-allgather (the check test_distributed already had and this one was missing).

loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
dist_print(f"Hybrid reshard_after_fwd iter {iteration}, loss {loss.item():.4f}")

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

Comment on lines +18 to +25
* ``te.Linear`` column-parallel and row-parallel, with and without
sequence parallelism.
* ``te.LayerNormLinear`` column-parallel with sequence parallelism —
the quantized-AG path that currently unfuses LN+quantize for
``HybridQuantizer``.
* ``te.TransformerLayer`` with ``set_parallel_mode=True`` and SP on —
integration test hitting LayerNormLinear + DPA + LayerNormMLP + row-
parallel output projection in one shot.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Considering that Transformer layer gives basically everything, what is the value of the other tests? And if there is value in the other tests, then why don't we check the LayerNormMLP on its own?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other tests are complementary, not redundant. Transformer layer test is the broad smoke. Standalone tests have additional value (grad checks + extra configs). Following this, adding LayerNormMLP. Also added a bitwise hybrid-vs-vanilla equivalence test on te.Linear.

Comment on lines +29 to +31
numerical signal is clean. Cross-format hybrid adds independent
numerical variation unrelated to TP/SP and is covered by single-GPU
tests already.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I agree with this assessment. Cross format is actually the only case where you need to be careful about the allgather cases being different in forward and backward and allgather touches the comm-gemm overlap that would also be potentially affected (e.g. due to wrong buffer sizes taken from forward recipe being used in backward).

Also, a general comment - hybrid recipes that do forward quantized/backward unquantized and vice versa would be very useful to test as well.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, you are right, AG is format dependent. Adding MXFP8 forward / NVFP4 backward. and updating the docstring.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, a general comment - hybrid recipes that do forward quantized/backward unquantized and vice versa would be very useful to test as well.

Looking into it.

Comment on lines +33 to +36
Tolerances match upstream ``run_numerics.py`` per-format settings (see
``_get_tolerances``); they're loose enough to absorb the amax-reduction
and stochastic numerical asymmetries inherent to distributed FP8, tight
enough to catch a silent BF16 fallback on the hybrid sub-storage path.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the tolerances be effectively 0 if you are doing the non-actually-hybrid recipes only? Since you should be comparing the same underlying implementations.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tolerances are for the distributed vs single-node comparison, not hybrid vs vanilla. I will reword this comment to make the two comparisons distinct.

columnwise_quantizer=_make_mxfp8_quantizer(),
)
if is_linear and role.tensor_type in ("grad_output", "grad_input"):
return _make_mxfp8_quantizer(fp8_dtype=tex.DType.kFloat8E5M2)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E5M2 here is not correct. In general this should just be a single line to return the mxfp8 quantizer for all cases.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, fixed

Comment on lines +126 to +131
"""Default NVFP4Quantizer: no RHT, no stochastic rounding, no 2D
scaling — matches upstream ``run_numerics.py::nvfp4_vanilla()`` which
strips the recipe to bare 1D block scaling for distributed TP
fairness. Same-format hybrid NVFP4 has no E5M2 variant (NVFP4 is a
single format family — E2M1 only), so grad roles reuse the same
NVFP4 quantizer."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we want to check the full recipe here?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to the full recipe except 1D for weights, will enabled after #3027 merge

Comment on lines +136 to +143
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=_make_nvfp4_quantizer(),
columnwise_quantizer=_make_nvfp4_quantizer(),
)
if is_linear and role.tensor_type in ("grad_output", "grad_input"):
return _make_nvfp4_quantizer()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As written those lines are not needed at all. They would be needed if you did the full recipe.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to the full recipe

Comment on lines +166 to +168
# quantization (rowwise and columnwise quantizers run independently, so
# their outputs may differ by ~1 ULP from a single fused-quantize path
# in edge cases).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That does not sound like a good thing if it actually happens in practice - the quantization only should not be affected if you do both at the same time vs one at a time -> the input and the algorithm is the same in both cases. Fusion with the activations could maybe give slightly different results, but I would still like to get an explanation of why that would be.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. If the algorithm is the same, we are indeed getting identical results. New bitwise linear_vs_vanilla test confirms this. The only place where two pass and fused differ is NVFP4 with RHT + SR. This activates a separate columnwise RNG (need_separate_columnwise_rng), and RNG stream consumed differently. see comment in _backward_not_bitwise_comparable(). Removed the comment.

negvet and others added 6 commits June 1, 2026 08:47
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Signed-off-by: Evgeny <etsykunov@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants